package api

import (
	"errors"
	"fmt"
	"io/ioutil"
	"math/rand"
	"net/http"
	neturl "net/url"
	"time"

	"gitee.com/satyr/tools/codec/jsoniter"

	ctx "gitee.com/satyr/tools/context"

	"gitee.com/satyr/tools/common"
	"gitee.com/satyr/tools/curl"
	"gitee.com/satyr/tools/encoding"
)

// Response is the envelope returned by internal third-party interfaces.
type Response struct {
	// ErrNo is the error code; StdCall treats any non-zero value as a failure.
	ErrNo   int64       `json:"err_no"`
	// ErrMsg is the message that accompanies ErrNo.
	ErrMsg  string      `json:"err_msg"`
	// Results carries the payload. StdCall stores the caller's destination
	// value here before the response is decoded.
	Results interface{} `json:"results"`
}

// RequestUnmarshal reads the body of the request held by httpCtx and decodes
// it as JSON into params.
//
// Validating and reading the body is delegated to RequestBody so the two
// functions cannot drift apart; the errors are the ones RequestBody produces:
//   - 500 when httpCtx or its Request is nil,
//   - 20064001 when the method is not POST,
//   - 400 when the body is nil or cannot be read.
//
// A body that is not valid JSON for params is logged on httpCtx and reported
// as a 400 "request params invalid" error. params should be a pointer,
// otherwise the decoded data is not visible to the caller.
func RequestUnmarshal(httpCtx *ctx.HTTPContext, params interface{}) (err error) {
	body, err := RequestBody(httpCtx)
	if err != nil {
		return err
	}

	httpCtx.Debugf("Request cmp Params is: %s", string(body))
	if err = jsoniter.Unmarshal(body, &params); err != nil {
		httpCtx.Error(err)
		return common.NewRespErr(400, "request params invalid")
	}

	return nil
}

// RequestBody returns the raw body of the request held by httpCtx.
//
// It fails with a 500 error when httpCtx or its Request is nil, with
// 20064001 when the method is not POST, and with 400 when the body is nil or
// cannot be read. The request body is closed before returning.
func RequestBody(httpCtx *ctx.HTTPContext) (body []byte, err error) {
	switch {
	case httpCtx == nil || httpCtx.Request == nil:
		return nil, common.NewRespErr(500, "nil httpCtx")
	case httpCtx.Request.Method != "POST":
		return nil, common.NewRespErr(20064001, "必须为post请求")
	case httpCtx.Request.Body == nil:
		return nil, common.NewRespErr(400, "body不能为空")
	}

	stream := httpCtx.Request.Body
	defer stream.Close()

	raw, readErr := ioutil.ReadAll(stream)
	if readErr != nil {
		return nil, common.NewRespErr(400, readErr)
	}
	return raw, nil
}

// StdCall invokes another internal "standard" HTTP service, i.e. one whose
// response contains err_no, err_msg and results.
//
// p is sent as the JSON request body. results is stored in the envelope's
// Results field before decoding, so it should be a pointer to the value that
// is to receive the "results" payload. A transport or decoding failure is
// returned as is; a non-zero err_no in the envelope is returned as a response
// error carrying err_no and err_msg.
func StdCall(httpCtx *ctx.HTTPContext, addresses []string, uri string, p interface{}, results interface{}) (err error) {
	var resp Response
	resp.Results = results

	// The envelope must be passed by pointer: a copy would leave resp.ErrNo
	// at zero here, silently hiding every error reported by the callee.
	if err = Call(httpCtx, map[string]string{}, addresses, uri, p, &resp); err != nil {
		return err
	}

	if resp.ErrNo != 0 {
		return common.NewRespErr(resp.ErrNo, resp.ErrMsg)
	}

	return nil
}

// Call invokes an HTTP service that returns arbitrary JSON.
//
// p is marshalled to JSON and posted through underCall together with the
// extra header entries; the response body is decoded into resp. A nil httpCtx
// (or nil httpCtx.Ctx) and a marshalling failure are both reported as 500
// errors.
func Call(httpCtx *ctx.HTTPContext, header map[string]string, addresses []string, uri string, p interface{}, resp interface{}) (err error) {
	if httpCtx == nil || httpCtx.Ctx == nil {
		return common.NewRespErr(500, "nil httpCtx")
	}

	payload, marshalErr := encoding.JSON.Marshal(p)
	if marshalErr != nil {
		return common.NewRespErr(500, marshalErr)
	}
	httpCtx.Debugf("call api req is %s", payload)

	err = underCall(httpCtx, header, addresses, uri, payload, resp)
	return err
}

// TransCall forwards the body of the incoming request unchanged to an HTTP
// service that returns arbitrary JSON, decoding the reply into resp.
//
// The body is obtained with RequestBody; any error from it is returned as is.
func TransCall(httpCtx *ctx.HTTPContext, header map[string]string, addresses []string, uri string, resp interface{}) (err error) {
	var passthrough []byte
	if passthrough, err = RequestBody(httpCtx); err != nil {
		return err
	}

	err = underCall(httpCtx, header, addresses, uri, passthrough, resp)
	return err
}

// underCall posts req as a JSON body to a URL built from addresses and uri
// and decodes the JSON response body into resp.
//
// Content-Type and Trace-Id headers are always set; AppID is set when the
// context carries a string under "app_id"; entries of header are applied last
// and therefore win. The request is attempted up to three times, picking a
// URL via getApiUrl for every attempt, and stops early when the context is
// done (the context error is then returned unwrapped).
//
// Errors: 500 for a nil httpCtx, an unusable address, exhausted retries or an
// undecodable body; a plain error for a nil response or a non-200 status.
func underCall(httpCtx *ctx.HTTPContext, header map[string]string, addresses []string, uri string, req []byte, resp interface{}) (err error) {
	if httpCtx == nil || httpCtx.Ctx == nil {
		return common.NewRespErr(500, "nil httpCtx")
	}

	c := curl.NewPost(httpCtx.Ctx, "")
	c.Headers.Set("Content-Type", "application/json")
	c.Headers.Set("Trace-Id", httpCtx.GetTraceID())
	// Checked assertion: a missing or non-string "app_id" skips the header
	// instead of panicking.
	if appID, ok := httpCtx.Ctx.Value("app_id").(string); ok {
		c.Headers.Set("AppID", appID)
	}
	for k, v := range header {
		c.Headers.Set(k, v)
	}
	c.SetTimeout(30) // NOTE(review): presumably seconds — confirm against the curl package
	c.PostBytes = req

	defer func(t time.Time) {
		httpCtx.Infof("Call: %v %s CostTime: %v", addresses, uri, time.Since(t))
	}(time.Now())

	var rs *curl.Response
	for i := 0; i < 3; i++ {
		// Give up between attempts as soon as the caller's context is done.
		select {
		case <-httpCtx.Ctx.Done():
			return httpCtx.Ctx.Err()
		default:
		}

		c.Url, err = getApiUrl(addresses, uri)
		if err != nil {
			return common.NewRespErr(500, err)
		}
		c.SetContext(httpCtx.Ctx)
		rs, err = c.Request()
		if err == nil {
			break
		}
		httpCtx.Warnf("Url: [%s] %s", c.Url, err.Error())
	}

	if err != nil {
		return common.NewRespErr(500, err)
	}
	// Same guard CallRaw applies: never dereference a nil response.
	if rs == nil {
		return fmt.Errorf("response is nil")
	}
	defer rs.Close()

	if rs.StatusCode != http.StatusOK {
		return fmt.Errorf("Call %s get http status code: %d status: %s",
			c.Url, rs.StatusCode, rs.Status)
	}

	if err = encoding.JSONIO.Unmarshal(rs.Body, &resp); err != nil {
		return common.NewRespErr(500, err)
	}

	return nil
}

// CallRaw posts p, marshalled to JSON, to a URL built from addresses and uri
// and returns the raw response body without decoding it.
//
// Content-Type and Trace-Id headers are always set; AppID is set when the
// context carries a string under "app_id"; entries of header are applied last
// and therefore win. The request is attempted up to three times, picking a
// URL via getApiUrl for every attempt, and stops early when the context is
// done (the context error is then returned unwrapped).
//
// Errors: 500 for a nil httpCtx, a marshalling failure, an unusable address
// or exhausted retries; a plain error for a nil response or a non-200 status;
// a failure while reading the body is returned as is.
func CallRaw(httpCtx *ctx.HTTPContext, header map[string]string, addresses []string, uri string, p interface{}) (data []byte, err error) {
	if httpCtx == nil || httpCtx.Ctx == nil {
		return nil, common.NewRespErr(500, "nil httpCtx")
	}
	c := curl.NewPost(httpCtx.Ctx, "")
	c.Headers.Set("Content-Type", "application/json")
	c.Headers.Set("Trace-Id", httpCtx.GetTraceID())
	// Checked assertion: a missing or non-string "app_id" skips the header
	// instead of panicking.
	if appID, ok := httpCtx.Ctx.Value("app_id").(string); ok {
		c.Headers.Set("AppID", appID)
	}
	for k, v := range header {
		c.Headers.Set(k, v)
	}
	c.SetTimeout(30) // NOTE(review): presumably seconds — confirm against the curl package
	c.PostBytes, err = encoding.JSON.Marshal(p)
	if err != nil {
		return nil, common.NewRespErr(500, err)
	}

	defer func(t time.Time) {
		httpCtx.Infof("Call: %v %s CostTime: %v", addresses, uri, time.Since(t))
	}(time.Now())

	var rs *curl.Response
	for i := 0; i < 3; i++ {
		// Give up between attempts as soon as the caller's context is done.
		select {
		case <-httpCtx.Ctx.Done():
			return nil, httpCtx.Ctx.Err()
		default:
		}

		c.Url, err = getApiUrl(addresses, uri)
		if err != nil {
			return nil, common.NewRespErr(500, err)
		}
		c.SetContext(httpCtx.Ctx)
		rs, err = c.Request()
		if err == nil {
			break
		}
		httpCtx.Warnf("Url: [%s] %s", c.Url, err.Error())
	}

	if err != nil {
		return nil, common.NewRespErr(500, err)
	}
	if rs == nil {
		return nil, fmt.Errorf("response is nil")
	}
	defer rs.Close()

	if rs.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Call %s get http status code: %d status: %s", c.Url, rs.StatusCode, rs.Status)
	}
	return ioutil.ReadAll(rs.Body)
}

// getApiUrl builds the request URL for one of addresses.
//
// With a single address that address is used; with several, one is chosen at
// random. The chosen address must parse as a request URI and carry a scheme,
// otherwise an "error url" error is returned. A non-empty uri is resolved
// against the address as a URL reference; an empty uri yields the address
// itself. An empty addresses slice is an error.
func getApiUrl(addresses []string, uri string) (string, error) {
	var host string
	switch count := len(addresses); count {
	case 0:
		return "", errors.New("nil addresses")
	case 1:
		host = addresses[0]
	default:
		picker := rand.New(rand.NewSource(time.Now().UnixNano()))
		host = addresses[picker.Intn(count)]
	}

	base, err := neturl.ParseRequestURI(host)
	if err != nil || base.Scheme == "" {
		return "", fmt.Errorf("error url: [%s]", host)
	}

	if uri == "" {
		return base.String(), nil
	}

	target, err := base.Parse(uri)
	if err != nil {
		return "", err
	}
	return target.String(), nil
}
